import torch
import torch.nn.functional as F

if __name__ == '__main__':
    # Small demo of F.one_hot on a few integer index tensors.
    labels = torch.arange(0, 5) % 3  # tensor([0, 1, 2, 0, 1])

    # With no num_classes given, one_hot infers the class count as labels.max() + 1 (here 3).
    print(F.one_hot(labels))

    # Passing num_classes explicitly widens the trailing one-hot axis to 5 columns.
    print(F.one_hot(labels, num_classes=5))

    # Multi-dimensional input: a class axis is appended, so the (3, 2) grid yields shape (3, 2, 3).
    grid = torch.arange(0, 6).view(3, 2) % 3
    print(F.one_hot(grid))
